
6.1 - Code - The Reusable train.py Script

This train.py script serves as the main entry point for our entire reinforcement learning application. It is the top-level, synchronous process that orchestrates the creation of the learning environment, the training of the agent, and the saving of the final model.

The script is designed to be reusable; by changing a single configuration variable, you can switch between training the WorkerEnv, MacroEnv, or MicroEnv.
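
The script below makes this switch through a module-level SELECTED_ENV variable. As an optional variation, the same choice could be read from the command line so the file never needs editing; the sketch below assumes an --env flag that is not part of the tutorial's script:

# Optional sketch: drive the selection from the command line instead of
# editing SELECTED_ENV (the --env flag here is illustrative only).
import argparse

parser = argparse.ArgumentParser(description="Train a StarCraft II RL agent")
parser.add_argument(
    "--env",
    choices=["worker", "macro", "micro"],
    default="micro",
    help="Which environment to train",
)
args = parser.parse_args()
SELECTED_ENV = args.env  # replaces the hard-coded assignment

With that in place, training a different agent is just: python train.py --env macro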

The Training Workflow

This script follows the standard machine learning training pipeline.

+--------------------------+
| 1. Select & Import Env |
+--------------------------+
|
v
+--------------------------+
| 2. Instantiate Env |
| (Vectorized) |
+--------------------------+
|
v
+--------------------------+
| 3. Instantiate RL Model |
| (e.g., PPO) |
+--------------------------+
|
v
+--------------------------+
| 4. Run Training Loop |
| (model.learn) |
+--------------------------+
|
v
+--------------------------+
| 5. Save Trained Model |
+--------------------------+

The Code: train.py

# train.py

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# --- Step 1: Import all possible environments ---
# These custom environment classes must follow the Gymnasium API specification.
from worker_bot import WorkerEnv
from macro_bot import MacroEnv
from micro_bot import MicroEnv
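# These classes are assumed to expose roughly the standard Gymnasium interface
# sketched below (illustrative only; the real WorkerEnv, MacroEnv, and MicroEnv
# define their own observation/action spaces and StarCraft II logic):
#
#     class SomeEnv(gymnasium.Env):
#         observation_space: gymnasium.spaces.Space
#         action_space: gymnasium.spaces.Space
#
#         def reset(self, *, seed=None, options=None):
#             ...
#             return observation, info
#
#         def step(self, action):
#             ...
#             return observation, reward, terminated, truncated, info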

# --- Configuration: Select the environment to train ---
# Use a dictionary to map names to environment classes for clean selection.
ENVIRONMENTS = {
    "worker": WorkerEnv,
    "macro": MacroEnv,
    "micro": MicroEnv,
}
# CHANGE THIS VALUE to select which environment to train
SELECTED_ENV = "micro"


def main():
    """The main function to configure and run the training process."""

    print(f"--- Starting training for environment: {SELECTED_ENV.upper()} ---")

    # --- Step 2: Instantiate the selected environment ---
    # make_vec_env wraps the environment in Stable Baselines3's vectorized
    # interface, the standard way SB3 handles resets and episode bookkeeping.
    # By default it uses DummyVecEnv, so the environment runs in the current
    # process; pass vec_env_cls=SubprocVecEnv if you want the StarCraft II
    # environment to run in a separate process instead.
    env_class = ENVIRONMENTS[SELECTED_ENV]
    env = make_vec_env(env_class, n_envs=1)

    # --- Step 3: Instantiate the PPO model ---
    # Proximal Policy Optimization (PPO) is a robust, general-purpose RL algorithm.
    # "MlpPolicy" specifies that the agent's brain will be a Multi-Layer Perceptron network.
    model = PPO(
        "MlpPolicy",
        env,
        verbose=1,
        tensorboard_log="./sc2_rl_tensorboard/",
    )

    # --- Step 4: Start the training loop ---
    # total_timesteps is the number of agent-environment interactions to perform.
    print(f"--- Beginning training for {100_000:,} timesteps... ---")
    model.learn(total_timesteps=100_000, progress_bar=True)
    print("--- Training complete. ---")

    # --- Step 5: Save the trained model ---
    # The learned policy is saved so it can be loaded later for evaluation.
    model_path = f"ppo_sc2_{SELECTED_ENV}.zip"
    model.save(model_path)
    print(f"--- Model saved to {model_path} ---")

    # --- Final Step: Clean up the environment ---
    # This ensures the vectorized environment and the StarCraft II game process are terminated.
    env.close()

if __name__ == '__main__':
    # This guard matters whenever the script may spawn subprocesses (e.g. if
    # the vectorized environment is switched to SubprocVecEnv). On platforms
    # that use the "spawn" start method, it prevents child processes from
    # re-importing and re-executing this script, which would otherwise create
    # processes recursively.
    main()
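
Evaluating the Saved Model

Once training finishes, the saved policy can be reloaded without retraining. The snippet below is a minimal sketch of that evaluation step, assuming the micro environment and the ppo_sc2_micro.zip file produced above; the rollout loop is illustrative and not part of train.py.

# evaluate.py (sketch)
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

from micro_bot import MicroEnv  # match the environment the model was trained on

env = make_vec_env(MicroEnv, n_envs=1)
model = PPO.load("ppo_sc2_micro.zip", env=env)

obs = env.reset()
done = False
while not done:
    # deterministic=True picks the most likely action instead of sampling
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)

env.close()

Training curves written to ./sc2_rl_tensorboard/ can be inspected with TensorBoard, e.g. tensorboard --logdir ./sc2_rl_tensorboard/.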